{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium  \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");  \n",
    "you may not use this file except in compliance with the License.  \n",
    "You may obtain a copy of the License at  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0  \n",
    "Unless required by applicable law or agreed to in writing, software  \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,  \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  \n",
    "See the License for the specific language governing permissions and  \n",
    "limitations under the License. \n",
    "\n",
    "# Customize Algorithm Ensemble in Auto3DSeg\n",
    "\n",
    "In this notebook, we will provide a brief example of how to customize your ensemble pipeline by defining a new ensemble class."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[nibabel, yaml, tqdm]\"\n",
    "%env CUDA_VISIBLE_DEVICES=0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import numpy as np\n",
    "import nibabel as nib\n",
    "import random\n",
    "\n",
    "from copy import deepcopy\n",
    "from pathlib import Path\n",
    "\n",
    "from monai.apps.auto3dseg import (\n",
    "    AlgoEnsemble,\n",
    "    AlgoEnsembleBuilder,\n",
    "    DataAnalyzer,\n",
    "    BundleGen,\n",
    ")\n",
    "\n",
    "from monai.bundle.config_parser import ConfigParser\n",
    "from monai.config import print_config\n",
    "from monai.data import create_test_image_3d\n",
    "from monai.utils.enums import AlgoKeys\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Simulate a special dataset\n",
    "\n",
    "It is well known that AI takes time to train. To provide the \"Hello World!\" experience of Auto3DSeg in this notebook, we will simulate a small dataset and run training only for multiple epochs. Due to the nature of AI, the performance shouldn't be highly expected, but the entire pipeline will be completed within minutes!\n",
    "\n",
    "`sim_datalist` provides the information of the simulated datasets. It lists 12 training and 2 testing images and labels.\n",
    "The training data are split into 3 folds. Each fold will use 8 images to train and 4 images to validate.\n",
    "The size of the dimension is defined by the `sim_dim`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim_datalist = {\n",
    "    \"testing\": [\n",
    "        {\"image\": \"test_image_001.nii.gz\", \"label\": \"test_label_001.nii.gz\"},\n",
    "        {\"image\": \"test_image_002.nii.gz\", \"label\": \"test_label_002.nii.gz\"},\n",
    "    ],\n",
    "    \"training\": [\n",
    "        {\"fold\": 0, \"image\": \"tr_image_001.nii.gz\", \"label\": \"tr_label_001.nii.gz\"},\n",
    "        {\"fold\": 0, \"image\": \"tr_image_002.nii.gz\", \"label\": \"tr_label_002.nii.gz\"},\n",
    "        {\"fold\": 0, \"image\": \"tr_image_003.nii.gz\", \"label\": \"tr_label_003.nii.gz\"},\n",
    "        {\"fold\": 0, \"image\": \"tr_image_004.nii.gz\", \"label\": \"tr_label_004.nii.gz\"},\n",
    "        {\"fold\": 1, \"image\": \"tr_image_005.nii.gz\", \"label\": \"tr_label_005.nii.gz\"},\n",
    "        {\"fold\": 1, \"image\": \"tr_image_006.nii.gz\", \"label\": \"tr_label_006.nii.gz\"},\n",
    "        {\"fold\": 1, \"image\": \"tr_image_007.nii.gz\", \"label\": \"tr_label_007.nii.gz\"},\n",
    "        {\"fold\": 1, \"image\": \"tr_image_008.nii.gz\", \"label\": \"tr_label_008.nii.gz\"},\n",
    "        {\"fold\": 2, \"image\": \"tr_image_009.nii.gz\", \"label\": \"tr_label_009.nii.gz\"},\n",
    "        {\"fold\": 2, \"image\": \"tr_image_010.nii.gz\", \"label\": \"tr_label_010.nii.gz\"},\n",
    "        {\"fold\": 2, \"image\": \"tr_image_011.nii.gz\", \"label\": \"tr_label_011.nii.gz\"},\n",
    "        {\"fold\": 2, \"image\": \"tr_image_012.nii.gz\", \"label\": \"tr_label_012.nii.gz\"},\n",
    "    ],\n",
    "}\n",
    "\n",
    "sim_dim = (64, 64, 64)"
   ]
  },
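  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fold layout described above can be checked with a short helper. This is a minimal sketch and not part of Auto3DSeg: `split_by_fold` is a hypothetical name, and `example_datalist` is a small stand-in that only mimics the `fold` keys used in `sim_datalist`. For a given validation fold, every case with a different fold index is treated as training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_by_fold(training_cases, val_fold):\n",
    "    \"\"\"Split cases into (train, val) lists using each case's \"fold\" key.\"\"\"\n",
    "    train = [c for c in training_cases if c[\"fold\"] != val_fold]\n",
    "    val = [c for c in training_cases if c[\"fold\"] == val_fold]\n",
    "    return train, val\n",
    "\n",
    "\n",
    "# Hypothetical stand-in for sim_datalist[\"training\"]: 12 cases, 4 per fold.\n",
    "example_datalist = [{\"fold\": i // 4, \"image\": f\"tr_image_{i + 1:03d}.nii.gz\"} for i in range(12)]\n",
    "\n",
    "for fold in range(3):\n",
    "    train_cases, val_cases = split_by_fold(example_datalist, fold)\n",
    "    print(f\"fold {fold}: {len(train_cases)} training cases, {len(val_cases)} validation cases\")"
   ]
  },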
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate images and labels\n",
    "\n",
    "Now we can use MONAI `create_test_image_3d` and `nib.Nifti1Image` functions to generate the 3D simulated images under the work_dir"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "work_dir = str(Path(\"./ensemble_byoc_work_dir\"))\n",
    "if not os.path.isdir(work_dir):\n",
    "    os.makedirs(work_dir)\n",
    "\n",
    "dataroot_dir = os.path.join(work_dir, \"sim_dataroot\")\n",
    "if not os.path.isdir(dataroot_dir):\n",
    "    os.makedirs(dataroot_dir)\n",
    "\n",
    "datalist_file = os.path.join(work_dir, \"sim_datalist.json\")\n",
    "with open(datalist_file, \"w\") as f:\n",
    "    json.dump(sim_datalist, f)\n",
    "\n",
    "for d in sim_datalist[\"testing\"] + sim_datalist[\"training\"]:\n",
    "    im, seg = create_test_image_3d(\n",
    "        sim_dim[0], sim_dim[1], sim_dim[2], rad_max=10, num_seg_classes=1, random_state=np.random.RandomState(42)\n",
    "    )\n",
    "    image_fpath = os.path.join(dataroot_dir, d[\"image\"])\n",
    "    label_fpath = os.path.join(dataroot_dir, d[\"label\"])\n",
    "    nib.save(nib.Nifti1Image(im, affine=np.eye(4)), image_fpath)\n",
    "    nib.save(nib.Nifti1Image(seg, affine=np.eye(4)), label_fpath)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a new class that inherit the AlgoEnsemble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyAlgoEnsemble(AlgoEnsemble):\n",
    "    \"\"\"\n",
    "    Randomly select N models to do ensemble\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_models=3):\n",
    "        super().__init__()\n",
    "        self.n_models = n_models\n",
    "\n",
    "    def collect_algos(self):\n",
    "        \"\"\"\n",
    "        collect_algos defines the method to collect the target algos from the self.algos list\n",
    "        \"\"\"\n",
    "        n = len(self.algos)\n",
    "        if self.n_models > n:\n",
    "            raise ValueError(f\"Number of loaded Algo is {n}, but {self.n_models} algos are requested.\")\n",
    "\n",
    "        indexes = list(range(n))\n",
    "        random.shuffle(indexes)\n",
    "        indexes = indexes[0 : self.n_models]\n",
    "        self.algo_ensemble = []\n",
    "        for idx in indexes:\n",
    "            self.algo_ensemble.append(deepcopy(self.algos[idx]))"
   ]
  },
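  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`random.shuffle` in `collect_algos` draws from Python's global random state, so the selected models can differ from run to run. The selection step can be sketched in isolation as follows. This is a minimal sketch, not part of the Auto3DSeg API: `pick_random_subset` and its `seed` argument are hypothetical, and plain strings stand in for the trained algorithms. Passing a seed makes the selection repeatable."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def pick_random_subset(items, n_models, seed=None):\n",
    "    \"\"\"Return n_models items chosen at random without replacement.\"\"\"\n",
    "    if n_models > len(items):\n",
    "        raise ValueError(f\"{len(items)} items are available, but {n_models} are requested.\")\n",
    "    rng = random.Random(seed)  # local generator; the global random state is untouched\n",
    "    return rng.sample(items, n_models)\n",
    "\n",
    "\n",
    "# Hypothetical identifiers standing in for trained algorithms.\n",
    "algo_ids = [\"algo_a_0\", \"algo_a_1\", \"algo_b_0\", \"algo_b_1\"]\n",
    "\n",
    "first = pick_random_subset(algo_ids, 3, seed=0)\n",
    "second = pick_random_subset(algo_ids, 3, seed=0)\n",
    "print(first)\n",
    "print(first == second)  # the same seed gives the same selection"
   ]
  },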
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run Auto3DSeg data analyzer, algo generation, and training\n",
    "\n",
 NOTE: For demo">
    "> NOTE: For demo purposes, the cell below overrides the training parameters of all algorithms with the same short settings, `train_param`.\n",
    "> To use more than one GPU, change the `CUDA_VISIBLE_DEVICES` environment variable, or remove it to use all available devices.\n",
    "> Also make sure the number of GPUs does not exceed the number of partitions the training dataset can be split into."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_stats_filename = os.path.join(work_dir, \"datastats.yaml\")\n",
    "da = DataAnalyzer(datalist_file, dataroot_dir, output_path=data_stats_filename)\n",
    "da.get_all_case_stats()\n",
    "\n",
    "input = {\n",
    "    \"modality\": \"MRI\",\n",
    "    \"datalist\": datalist_file,\n",
    "    \"dataroot\": dataroot_dir,\n",
    "}\n",
    "\n",
    "input_cfg = \"input.yaml\"\n",
    "ConfigParser.export_config_file(input, input_cfg)\n",
    "\n",
    "bundle_generator = BundleGen(algo_path=work_dir, data_stats_filename=data_stats_filename, data_src_cfg_name=input_cfg)\n",
    "bundle_generator.generate(work_dir, num_fold=2)\n",
    "history = bundle_generator.get_history()\n",
    "\n",
    "max_epochs = 2\n",
    "\n",
    "train_param = {\n",
    "    \"num_epochs_per_validation\": 1,\n",
    "    \"num_images_per_batch\": 2,\n",
    "    \"num_epochs\": max_epochs,\n",
    "    \"num_warmup_epochs\": 1,\n",
    "}\n",
    "\n",
    "for algo_dict in history:\n",
    "    algo = algo_dict[AlgoKeys.ALGO]\n",
    "    algo.train(train_param)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Apply MyAlgoEnsemble\n",
    "\n",
    "As we defined earlier in this notebook, `MyAlgoEnsemble` randomly shuffles the trained models, and picks three for ensemble."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "builder = AlgoEnsembleBuilder(history, input_cfg)\n",
    "builder.set_ensemble_method(MyAlgoEnsemble())\n",
    "ensemble = builder.get_ensemble()\n",
    "preds = ensemble()\n",
    "\n",
    "print(\"The ensemble randomly picks the following models:\")\n",
    "for algo in ensemble.get_algo_ensemble():\n",
    "    print(algo[AlgoKeys.ID])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
